import torch
import numpy as np
from bgflow.utils import (
    remove_mean,
)
from eq_ot_flow.models import EGNN_dynamics

from path_grad_helpers import (
    device,
)

# Deliberately shadows the `device` imported from path_grad_helpers above:
# prefer CUDA when it is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def get_data(data_path, n_particles):
    """Load the ``all_data_LJ*`` dataset for ``n_particles`` as a tensor.

    For ``n_particles == 13`` the data is read from a single file
    (``all_data_LJ13-1000.npy``). For any other particle count it is read
    from two files (``...-part1.npy`` and ``...-part2.npy``), which are
    concatenated in that order.

    The loaded array is passed through bgflow's ``remove_mean`` with
    ``n_dimensions=3``. Presumably this subtracts the per-configuration
    mean position; confirm against bgflow.

    Args:
        data_path: Directory containing the ``.npy`` files.
        n_particles: Number of particles per configuration.

    Returns:
        A ``torch.Tensor`` created with ``torch.from_numpy`` from the
        ``remove_mean`` result.
    """
    prefix = f"{data_path}/all_data_LJ{n_particles}-1000"
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
    # objects; only load these files from a trusted source.
    if n_particles == 13:
        raw = np.load(f"{prefix}.npy", allow_pickle=True)
    else:
        chunks = []
        for part_idx in (1, 2):
            chunks.append(np.load(f"{prefix}-part{part_idx}.npy", allow_pickle=True))
        raw = np.concatenate(chunks)
    centred = remove_mean(raw, n_particles, 3)
    return torch.from_numpy(centred)


def get_dynamics(n_particles):
    """Build the ``EGNN_dynamics`` network for an ``n_particles`` system.

    The two configurations differ only in network size:

    - ``n_particles == 13`` uses ``hidden_nf=32`` and ``n_layers=3``.
    - Every other particle count uses ``hidden_nf=64`` and ``n_layers=7``.

    All other constructor arguments are shared.

    Args:
        n_particles: Number of particles in the system.

    Returns:
        The constructed ``EGNN_dynamics`` instance. The module-level
        ``device`` is passed as its ``device`` argument.
    """
    if n_particles == 13:
        hidden_nf, n_layers = 32, 3
    else:
        hidden_nf, n_layers = 64, 7
    return EGNN_dynamics(
        n_particles=n_particles,
        # Use the module-level device for every configuration. It falls back
        # to CPU when CUDA is unavailable; hard-coding "cuda" here would
        # break CPU-only hosts.
        device=device,
        n_dimension=3,
        hidden_nf=hidden_nf,
        act_fn=torch.nn.SiLU(),
        n_layers=n_layers,
        recurrent=True,
        tanh=True,
        attention=True,
        condition_time=True,
        mode="egnn_dynamics",
        agg="sum",
    )
